import json
import os


def add_lang_by_task(target_str, task, sub_task):
    """Prepend the output-language tag that matches ``task`` to ``target_str``.

    Args:
        target_str: Target text to tag.
        task: Task name. "summarize" yields ``<en>``, "refine" and "concode"
            yield ``<java>``, and "translate" yields ``<c_sharp>`` for the
            "java-cs" sub-task and ``<java>`` otherwise.
        sub_task: Sub-task name; only consulted for the "translate" task.

    Returns:
        The tag, a single space, and ``target_str``. For "defect" and any
        other task the input is returned unchanged.
    """
    if task == "translate":
        tag = "<c_sharp> " if sub_task == "java-cs" else "<java> "
    elif task == "summarize":
        tag = "<en> "
    elif task in ("refine", "concode"):
        tag = "<java> "
    else:
        # "defect" and unrecognised tasks carry no language tag.
        return target_str
    return tag + target_str


def convert_examples_to_features(item):
    """Tokenize one example into padded source and target id lists.

    Args:
        item: A 5-tuple ``(example, example_index, tokenizer, args, stage)``.

    The source text is optionally prefixed with the task name (and sub-task,
    unless it is "none") when the model type is "t5"/"codet5" and
    ``args.add_task_prefix`` is set. Any literal ``</s>`` in the text is
    replaced by ``<unk>`` before encoding with truncation and padding to
    ``args.max_source_length``. When ``stage`` is "test" the target id list is
    empty; otherwise the target is encoded the same way against
    ``args.max_target_length``.

    Returns:
        An ``InputFeatures`` holding the example index, both id lists and
        ``example.url``.

    Raises:
        AssertionError: If an encoded sequence does not contain exactly one
            EOS token id.
        NameError: If the task is "defect" or "clone" and the target is
            neither 0 nor 1.
    """
    example, example_index, tokenizer, args, stage = item

    use_prefix = args.model_type in ["t5", "codet5"] and args.add_task_prefix
    if not use_prefix:
        source_str = example.source
    elif args.sub_task == "none":
        source_str = f"{args.task}: {example.source}"
    else:
        source_str = f"{args.task} {args.sub_task}: {example.source}"

    # A literal "</s>" in the text must not show up as an extra EOS token.
    source_ids = tokenizer.encode(
        source_str.replace("</s>", "<unk>"),
        max_length=args.max_source_length,
        padding="max_length",
        truncation=True,
    )
    assert source_ids.count(tokenizer.eos_token_id) == 1

    if stage == "test":
        # No target is encoded at test time.
        return InputFeatures(example_index, source_ids, [], url=example.url)

    if args.add_lang_ids:
        target_str = add_lang_by_task(example.target, args.task, args.sub_task)
    else:
        target_str = example.target

    if args.task in ["defect", "clone"]:
        # Binary labels are spelled out as words so they can be tokenized.
        if target_str == 0:
            target_str = "false"
        elif target_str == 1:
            target_str = "true"
        else:
            raise NameError

    target_ids = tokenizer.encode(
        target_str.replace("</s>", "<unk>"),
        max_length=args.max_target_length,
        padding="max_length",
        truncation=True,
    )
    assert target_ids.count(tokenizer.eos_token_id) == 1

    return InputFeatures(example_index, source_ids, target_ids, url=example.url)


def decoder_convert_examples_to_features(item):
    """Build features for one example where source and target share a sequence.

    Args:
        item: A 5-tuple ``(example, example_index, tokenizer, args, stage)``.

    The source is encoded (truncated so that the encoded ``[SEP]`` ids still
    fit in ``args.max_source_length``) and followed by the ``[SEP]`` ids.

    * ``stage == "test"``: ``source_ids`` and ``target_ids`` are empty lists
      and ``source_ids_decoder`` is the left-padded source.
    * otherwise: ``source_ids`` is source + target (+ EOS), ``target_ids`` is
      the same sequence with every source position replaced by -100, both
      padded to ``max_source_length + max_target_length`` on the tokenizer's
      padding side, and ``source_ids_decoder`` is the left-padded source.

    Returns:
        An ``InputFeatures`` with the fields described above and
        ``example.url``.
    """
    example, example_index, tokenizer, args, stage = item

    # "[SEP]" may encode to more than one id; reserve room for all of them.
    sep_ids = tokenizer.encode("[SEP]")
    sep_count = len(sep_ids)

    if stage == "test":
        prompt_text = example.source.replace("</s>", "<unk>")
        source_len = args.max_source_length

        prompt_ids = tokenizer.encode(prompt_text, max_length=source_len - sep_count, truncation=True) + sep_ids
        left_pad = [tokenizer.pad_token_id] * (source_len - len(prompt_ids))

        return InputFeatures(example_index, [], [], url=example.url, source_ids_decoder=left_pad + prompt_ids)

    padding_side = tokenizer.padding_side
    prompt_text = example.source.replace("</s>", "<unk>")
    completion_text = example.target.replace("</s>", "<unk>")

    source_len = args.max_source_length
    target_len = args.max_target_length
    max_len = source_len + target_len

    # Each half is truncated to its own budget; one target slot is kept for EOS.
    prompt_ids = tokenizer.encode(prompt_text, max_length=source_len - sep_count, truncation=True) + sep_ids
    completion_ids = tokenizer.encode(completion_text, max_length=target_len - 1, truncation=True)
    completion_ids = completion_ids + [tokenizer.eos_token_id]

    # Source positions are set to -100 in the labels so no loss is computed on them.
    input_ids = prompt_ids + completion_ids
    label_ids = [-100] * len(prompt_ids) + completion_ids

    # NOTE(review): labels are padded with pad_token_id, not -100 — confirm
    # that the training loop masks padding positions itself.
    pad_id = tokenizer.pad_token_id
    input_pad = [pad_id] * (max_len - len(input_ids))
    label_pad = [pad_id] * (max_len - len(label_ids))
    if padding_side == "left":
        input_ids = input_pad + input_ids
        label_ids = label_pad + label_ids
    elif padding_side == "right":
        input_ids = input_ids + input_pad
        label_ids = label_ids + label_pad
    # Any other padding_side value leaves both sequences unpadded.

    # The source-only copy is always left-padded.
    prompt_only = [pad_id] * (source_len - len(prompt_ids)) + prompt_ids

    return InputFeatures(example_index, input_ids, label_ids, url=example.url, source_ids_decoder=prompt_only)


def convert_clone_examples_to_features(item):
    """Encode both snippets of a clone pair into one concatenated id list.

    Args:
        item: A 4-tuple ``(example, example_index, tokenizer, args)``.

    Each snippet is optionally prefixed with ``"<task>: "`` (model type
    "t5"/"codet5" with ``args.add_task_prefix`` set) and encoded with
    truncation and padding to ``args.max_source_length``. The two encodings
    are concatenated, first snippet first.

    Returns:
        A ``CloneInputFeatures`` with the concatenated ids, ``example.label``
        and both URLs.
    """
    example, example_index, tokenizer, args = item

    def encode(text):
        # Both snippets share the source-length budget.
        return tokenizer.encode(
            text,
            max_length=args.max_source_length,
            padding="max_length",
            truncation=True,
        )

    if args.model_type in ["t5", "codet5"] and args.add_task_prefix:
        first_text = f"{args.task}: {example.source}"
        second_text = f"{args.task}: {example.target}"
    else:
        first_text = example.source
        second_text = example.target

    first_ids = encode(first_text)
    second_ids = encode(second_text)
    return CloneInputFeatures(example_index, first_ids + second_ids, example.label, example.url1, example.url2)


def convert_defect_examples_to_features(item):
    """Encode one defect-detection example.

    Args:
        item: A 4-tuple ``(example, example_index, tokenizer, args)``.

    The source is optionally prefixed with ``"<task>: "`` (model type
    "t5"/"codet5" with ``args.add_task_prefix`` set) and encoded with
    truncation and padding to ``args.max_source_length``.

    Returns:
        A ``DefectInputFeatures`` with the token ids and ``example.target`` as
        the label.
    """
    example, example_index, tokenizer, args = item

    wants_prefix = args.model_type in ["t5", "codet5"] and args.add_task_prefix
    source_str = f"{args.task}: {example.source}" if wants_prefix else example.source

    token_ids = tokenizer.encode(
        source_str,
        max_length=args.max_source_length,
        padding="max_length",
        truncation=True,
    )
    return DefectInputFeatures(example_index, token_ids, example.target)


class CloneInputFeatures:
    """Features of one clone-detection pair.

    Stores the constructor arguments unchanged as attributes of the same
    name: ``example_id``, ``source_ids``, ``label``, ``url1`` and ``url2``.
    """

    def __init__(self, example_id, source_ids, label, url1, url2):
        self.example_id = example_id
        self.source_ids = source_ids
        self.label = label
        # Identifiers of the two snippets that make up the pair.
        self.url1, self.url2 = url1, url2


class DefectInputFeatures:
    """Features of one defect-detection example.

    Stores the constructor arguments unchanged as ``example_id``,
    ``source_ids`` and ``label``.
    """

    def __init__(self, example_id, source_ids, label):
        self.example_id = example_id
        self.source_ids, self.label = source_ids, label


class InputFeatures:
    """Features of one source/target example.

    Stores the constructor arguments unchanged as ``example_id``,
    ``source_ids``, ``target_ids``, ``url`` and ``source_ids_decoder``; the
    last two default to ``None``.
    """

    def __init__(self, example_id, source_ids, target_ids, url=None, source_ids_decoder=None):
        self.example_id = example_id
        self.source_ids, self.target_ids = source_ids, target_ids
        self.url = url
        # Only supplied by the decoder-style feature builder in this file.
        self.source_ids_decoder = source_ids_decoder


class Example:
    """One raw source/target example.

    Stores the constructor arguments unchanged as ``idx``, ``source``,
    ``target``, ``url``, ``task`` and ``sub_task``. ``url`` defaults to
    ``None``; ``task`` and ``sub_task`` default to the empty string.
    """

    def __init__(self, idx, source, target, url=None, task="", sub_task=""):
        self.idx = idx
        self.source, self.target = source, target
        self.url = url
        self.task, self.sub_task = task, sub_task


class CloneExample:
    """One raw clone-detection pair.

    ``code1`` is stored as ``source`` and ``code2`` as ``target``; ``label``,
    ``url1`` and ``url2`` are stored under their own names.
    """

    def __init__(self, code1, code2, label, url1, url2):
        # The two snippets reuse the source/target attribute names.
        self.source, self.target = code1, code2
        self.label = label
        self.url1, self.url2 = url1, url2


def read_translate_examples(filename, data_num):
    """Read line-aligned source/target examples for the translate task.

    Args:
        filename: Two comma-separated paths, ``"<source file>,<target file>"``.
        data_num: Stop once this many examples have been read. The count is
            compared for equality after each example, so a value it never
            reaches (e.g. -1) reads every line pair.

    Returns:
        A list of ``Example`` objects. ``idx`` is the 0-based line number and
        ``source``/``target`` are the whitespace-stripped lines. Reading stops
        at the end of the shorter file.

    Raises:
        AssertionError: If ``filename`` does not contain exactly two paths.
    """
    paths = filename.split(",")
    assert len(paths) == 2
    src_filename, trg_filename = paths

    examples = []
    # Decode as UTF-8 explicitly, like the JSON readers in this module,
    # instead of relying on the platform's locale default.
    with open(src_filename, encoding="utf-8") as f1, open(trg_filename, encoding="utf-8") as f2:
        for idx, (line1, line2) in enumerate(zip(f1, f2)):
            examples.append(Example(idx=idx, source=line1.strip(), target=line2.strip()))
            if idx + 1 == data_num:
                break
    return examples


def read_refine_examples(filename, data_num):
    """Read line-aligned source/target examples for the refine task.

    Args:
        filename: Two comma-separated paths, ``"<source file>,<target file>"``.
        data_num: Stop once this many examples have been read. The count is
            compared for equality after each example, so a value it never
            reaches (e.g. -1) reads every line pair.

    Returns:
        A list of ``Example`` objects. ``idx`` is the 0-based line number and
        ``source``/``target`` are the whitespace-stripped lines. Reading stops
        at the end of the shorter file.

    Raises:
        AssertionError: If ``filename`` does not contain exactly two paths.
    """
    paths = filename.split(",")
    assert len(paths) == 2
    src_filename, trg_filename = paths

    examples = []
    # Decode as UTF-8 explicitly, like the JSON readers in this module,
    # instead of relying on the platform's locale default.
    with open(src_filename, encoding="utf-8") as f1, open(trg_filename, encoding="utf-8") as f2:
        for idx, (line1, line2) in enumerate(zip(f1, f2)):
            examples.append(Example(idx=idx, source=line1.strip(), target=line2.strip()))
            if idx + 1 == data_num:
                break
    return examples


def read_concode_examples(filename, data_num):
    """Read concode examples from a JSON Lines file.

    Each line must be a JSON object with an ``"nl"`` and a ``"code"`` string.

    Args:
        filename: Path of the JSON Lines file.
        data_num: Stop once this many examples have been read. The count is
            compared for equality after each example, so a value it never
            reaches (e.g. -1) reads the whole file.

    Returns:
        A list of ``Example`` objects. ``idx`` is the 0-based line number,
        ``source`` the stripped ``"nl"`` text and ``target`` the stripped
        ``"code"`` text.
    """
    examples = []
    # Decode as UTF-8 explicitly rather than with the locale default.
    with open(filename, encoding="utf-8") as f:
        for idx, line in enumerate(f):
            record = json.loads(line)
            examples.append(Example(idx=idx, source=record["nl"].strip(), target=record["code"].strip()))
            # Compare against the number read so far; the loop variable is
            # left untouched.
            if idx + 1 == data_num:
                break
    return examples


def read_summarize_examples(filename, data_num):
    """Read code-summarization examples from a UTF-8 JSON Lines file.

    Each line must be a JSON object with a ``"code"`` string and a
    ``"docstring_tokens"`` list of strings.

    Args:
        filename: Path of the JSON Lines file.
        data_num: Stop once this many examples have been read; a value the
            count never reaches (e.g. -1) reads the whole file.

    Returns:
        A list of ``Example`` objects. ``idx`` is the 0-based line number,
        ``source`` the stripped code and ``target`` the docstring tokens
        joined by single spaces, with newlines removed and whitespace
        collapsed.
    """
    examples = []
    with open(filename, encoding="utf-8") as f:
        for count, raw_line in enumerate(f, start=1):
            position = count - 1
            js = json.loads(raw_line.strip())
            # Fills in a missing "idx" on the parsed record; nothing below reads it.
            if "idx" not in js:
                js["idx"] = position

            code = js["code"].strip()
            summary = " ".join(js["docstring_tokens"]).replace("\n", "")
            # split() with no arguments also drops leading/trailing whitespace.
            summary = " ".join(summary.split())

            examples.append(Example(idx=position, source=code, target=summary))
            if count == data_num:
                break
    return examples


def read_defect_examples(filename, data_num):
    """Read defect-detection examples from a UTF-8 JSON Lines file.

    Each line must be a JSON object with ``"func"``, ``"idx"`` and
    ``"target"`` entries.

    Args:
        filename: Path of the JSON Lines file.
        data_num: Stop once this many examples have been read; a value the
            count never reaches (e.g. -1) reads the whole file.

    Returns:
        A list of ``Example`` objects. ``idx`` and ``target`` are taken from
        the record as-is and ``source`` is ``"func"`` with every whitespace
        run collapsed to a single space.
    """
    examples = []
    with open(filename, encoding="utf-8") as f:
        for count, raw_line in enumerate(f, start=1):
            record = json.loads(raw_line.strip())
            flattened = " ".join(record["func"].split())
            examples.append(Example(idx=record["idx"], source=flattened, target=record["target"]))
            if count == data_num:
                break
    return examples


def read_clone_examples(filename, data_num):
    """Read clone-detection pairs from an index file and its code table.

    The code table ``data.jsonl`` is looked up in the same directory as
    ``filename``. Each of its lines is a JSON object with ``"idx"`` and
    ``"func"``. Each index line is ``"<url1>\\t<url2>\\t<label>"``.

    Args:
        filename: Path of the tab-separated index file.
        data_num: Stop once this many pairs have been kept; a value the count
            never reaches (e.g. -1) reads the whole index.

    Returns:
        A list of ``CloneExample`` objects. Pairs that reference an id missing
        from the code table are skipped and do not count towards
        ``data_num``. The label is 0 for the string "0" and 1 otherwise.

    Raises:
        ValueError: If an index line does not have exactly three
            tab-separated fields.
    """
    # os.path keeps a bare filename relative to the working directory; the
    # previous "/"-join resolved it to "/data.jsonl" at the filesystem root.
    code_path = os.path.join(os.path.dirname(filename), "data.jsonl")

    url_to_code = {}
    with open(code_path, encoding="utf-8") as f:
        for line in f:
            js = json.loads(line.strip())
            # Collapse every whitespace run so each function is a single line.
            url_to_code[js["idx"]] = " ".join(js["func"].split())

    data = []
    with open(filename, encoding="utf-8") as f:
        for line in f:
            url1, url2, label = line.strip().split("\t")
            if url1 not in url_to_code or url2 not in url_to_code:
                continue
            label = 0 if label == "0" else 1
            data.append(CloneExample(url_to_code[url1], url_to_code[url2], label, url1, url2))
            if len(data) == data_num:
                break
    return data


def read_mathqa_examples(filename, data_num):
    """Read MathQA examples from a UTF-8 JSON file.

    The file must hold a JSON sequence of objects with ``"code"``, ``"text"``
    and ``"task_id"`` entries.

    Args:
        filename: Path of the JSON file.
        data_num: Stop once this many examples have been read; a value the
            count never reaches (e.g. -1) reads every sample.

    Returns:
        A list of ``Example`` objects. ``idx`` is the sample's ``"task_id"``,
        ``source`` its ``"text"`` unchanged and ``target`` its stripped
        ``"code"``.
    """
    with open(filename, encoding="utf-8") as f:
        samples = json.load(f)

    examples = []
    for count, sample in enumerate(samples, start=1):
        program = sample["code"].strip()
        question = sample["text"]
        examples.append(Example(idx=sample["task_id"], source=question, target=program))
        if count == data_num:
            break

    return examples


def read_fixeval_examples(filename, data_num):
    """Read line-aligned source/target examples for the FixEval task.

    Args:
        filename: Two comma-separated paths, ``"<source file>,<target file>"``.
        data_num: Stop once this many examples have been read. The count is
            compared for equality after each example, so a value it never
            reaches (e.g. -1) reads every line pair.

    Returns:
        A list of ``Example`` objects. ``idx`` is the 0-based line number and
        ``source``/``target`` are the whitespace-stripped lines. Reading stops
        at the end of the shorter file.

    Raises:
        AssertionError: If ``filename`` does not contain exactly two paths.
    """
    paths = filename.split(",")
    assert len(paths) == 2
    src_filename, tgt_filename = paths

    examples = []
    # Decode as UTF-8 explicitly, like the JSON readers in this module,
    # instead of relying on the platform's locale default.
    with open(src_filename, "r", encoding="utf-8") as f1, open(tgt_filename, "r", encoding="utf-8") as f2:
        for idx, (line1, line2) in enumerate(zip(f1, f2)):
            examples.append(Example(idx=idx, source=line1.strip(), target=line2.strip()))
            if idx + 1 == data_num:
                break
    return examples


def read_mbpp_examples(filename, data_num):
    """Read MBPP examples from a JSON Lines file.

    Each line must be a JSON object with a ``"text"`` and a ``"code"`` string.

    Args:
        filename: Path of the JSON Lines file.
        data_num: Stop once this many examples have been read. The count is
            compared for equality after each example, so a value it never
            reaches (e.g. -1) reads the whole file.

    Returns:
        A list of ``Example`` objects. ``idx`` is the 0-based line number,
        ``source`` the stripped ``"text"`` and ``target`` the stripped
        ``"code"`` with backslash escape sequences interpreted.

    Raises:
        UnicodeDecodeError: If the code contains a malformed escape sequence.
    """
    examples = []
    # Decode as UTF-8 explicitly rather than with the locale default.
    with open(filename, "r", encoding="utf-8") as f:
        for idx, line in enumerate(f):
            sample = json.loads(line)
            src = sample["text"].strip()
            tgt = sample["code"].strip()
            # "unicode_escape" decodes its input as Latin-1, so going through
            # UTF-8 bytes turned every non-ASCII character into mojibake.
            # Latin-1 with backslashreplace round-trips those characters and
            # still interprets the backslash escapes.
            tgt = tgt.encode("latin-1", "backslashreplace").decode("unicode_escape")
            examples.append(Example(idx=idx, source=src, target=tgt))

            if idx + 1 == data_num:
                break

    return examples


def read_conala_examples(filename, data_num):
    """Read CoNaLa examples from a JSON Lines file.

    Each line must be a JSON object with ``"intent"`` and ``"snippet"``
    strings and optionally a ``"rewritten_intent"``.

    Args:
        filename: Path of the JSON Lines file.
        data_num: Stop once this many examples have been read. The count is
            compared for equality after each example, so a value it never
            reaches (e.g. -1) reads the whole file.

    Returns:
        A list of ``Example`` objects. ``idx`` is the 0-based line number.
        ``source`` is ``"rewritten_intent"`` when present and not blank,
        otherwise ``"intent"``. ``target`` is ``"snippet"`` with backslash
        escape sequences interpreted.

    Raises:
        UnicodeDecodeError: If the snippet contains a malformed escape
            sequence.
    """
    examples = []
    # Decode as UTF-8 explicitly rather than with the locale default.
    with open(filename, "r", encoding="utf-8") as f:
        for idx, line in enumerate(f):
            sample = json.loads(line)
            src = sample.get("rewritten_intent", None)
            if src is None or len(src.strip()) == 0:
                # Fall back to the original intent when no rewrite is available.
                src = sample["intent"]

            # "unicode_escape" decodes its input as Latin-1, so going through
            # UTF-8 bytes turned every non-ASCII character into mojibake.
            # Latin-1 with backslashreplace round-trips those characters and
            # still interprets the backslash escapes.
            tgt = sample["snippet"].encode("latin-1", "backslashreplace").decode("unicode_escape")
            examples.append(Example(idx=idx, source=src, target=tgt))

            if idx + 1 == data_num:
                break

    return examples


def read_avatar_examples(filename, data_num):
    """Read AVATAR source/target examples from two parallel JSON files.

    Each file must hold a JSON array of objects with ``"id"`` and ``"code"``
    entries; the arrays must have equal length and matching ids at every
    position.

    Args:
        filename: Two comma-separated paths, ``"<source file>,<target file>"``.
        data_num: Stop once this many examples have been read. The count is
            compared for equality after each example, so a value it never
            reaches (e.g. -1) reads every entry.

    Returns:
        A list of ``Example`` objects. ``idx`` is the entry's ``"id"`` and
        ``source``/``target`` are the two ``"code"`` strings with backslash
        escape sequences interpreted.

    Raises:
        AssertionError: If ``filename`` does not contain exactly two paths,
            the files differ in length, or a pair of ids does not match.
        UnicodeDecodeError: If a code string contains a malformed escape
            sequence.
    """
    paths = filename.split(",")
    assert len(paths) == 2
    src_filename, tgt_filename = paths

    # Decode as UTF-8 explicitly rather than with the locale default.
    with open(src_filename, "r", encoding="utf-8") as f1, open(tgt_filename, "r", encoding="utf-8") as f2:
        srcdata = json.load(f1)
        tgtdata = json.load(f2)

    assert len(srcdata) == len(tgtdata)

    def unescape(code):
        # "unicode_escape" decodes its input as Latin-1, so going through
        # UTF-8 bytes turned every non-ASCII character into mojibake.
        # Latin-1 with backslashreplace round-trips those characters and
        # still interprets the backslash escapes.
        return code.encode("latin-1", "backslashreplace").decode("unicode_escape")

    examples = []
    for count, (src_item, tgt_item) in enumerate(zip(srcdata, tgtdata), start=1):
        assert src_item["id"] == tgt_item["id"]

        examples.append(
            Example(
                idx=src_item["id"],
                source=unescape(src_item["code"]),
                target=unescape(tgt_item["code"]),
            )
        )

        if count == data_num:
            break

    return examples
